Coverage for cpprb/util.py: 72%
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import numpy as np
3from gym.spaces import Box, Discrete, MultiDiscrete, MultiBinary, Tuple, Dict
def from_space(space, int_type, float_type):
    """
    Build a single ``env_dict`` entry (``{"dtype": ..., "shape": ...}``) for one gym space.

    Parameters
    ----------
    space : gym.Space
        A non-composite space: ``Discrete``, ``MultiDiscrete``, ``Box`` or ``MultiBinary``.
    int_type : np.dtype
        dtype recorded for integer-valued spaces (``Discrete``, ``MultiDiscrete``, ``MultiBinary``).
    float_type : np.dtype
        dtype recorded for ``Box`` spaces.

    Returns
    -------
    dict
        ``{"dtype": <dtype>, "shape": <shape>}``

    Raises
    ------
    NotImplementedError
        If ``space`` is none of the supported types (composite ``Tuple``/``Dict``
        spaces included; those are expanded by the callers instead).
    """
    # Decide dtype and shape first, then assemble the entry in one place.
    if isinstance(space, Discrete):
        dtype, shape = int_type, 1
    elif isinstance(space, MultiDiscrete):
        dtype, shape = int_type, space.nvec.shape
    elif isinstance(space, Box):
        dtype, shape = float_type, space.shape
    elif isinstance(space, MultiBinary):
        dtype, shape = int_type, space.n
    else:
        raise NotImplementedError(f"Error: Unknown Space {space}")
    return {"dtype": dtype, "shape": shape}
def create_env_dict(env, *, int_type=None, float_type=None):
    """
    Create ``env_dict`` from Open AI ``gym.space`` for ``ReplayBuffer`` constructor

    Parameters
    ----------
    env : gym.Env
        Environment
    int_type: np.dtype, optional
        Integer type. Default is ``np.int32``
    float_type: np.dtype, optional
        Floating point type. Default is ``np.float32``

    Returns
    -------
    env_dict : dict
        ``env_dict`` parameter for ``ReplayBuffer`` class. It always contains
        ``"rew"`` and ``"done"``. Observation and action entries are named
        according to the space type:

        * plain space: ``"obs"``, ``"next_obs"``, ``"act"``
        * ``Tuple`` space: ``"obs{i}"``, ``"next_obs{i}"``, ``"act{i}"``
        * ``Dict`` space: ``"obs_{key}"``, ``"next_obs_{key}"``, ``"act_{key}"``

        These are the same names produced by the function returned from
        ``create_before_add_func``.

    Raises
    ------
    NotImplementedError
        Propagated from ``from_space`` when a (sub-)space type is unsupported.
    """
    # Compare against None explicitly: a plain ``np.dtype`` is falsy (its
    # ``len`` is the number of structured fields, i.e. 0), so ``x or default``
    # would silently replace e.g. ``np.dtype("int64")`` with the default.
    if int_type is None:
        int_type = np.int32
    if float_type is None:
        float_type = np.float32

    env_dict = {"rew": {"shape": 1, "dtype": float_type},
                "done": {"shape": 1, "dtype": float_type}}

    observation_space = env.observation_space
    action_space = env.action_space

    if isinstance(observation_space, Tuple):
        for i, s in enumerate(observation_space.spaces):
            env_dict[f"obs{i}"] = from_space(s, int_type, float_type)
            env_dict[f"next_obs{i}"] = from_space(s, int_type, float_type)
    elif isinstance(observation_space, Dict):
        # Prefix sub-space keys ("obs_<key>" / "next_obs_<key>") to match the
        # naming used by ``create_before_add_func`` and to avoid collisions
        # with "rew", "done" or action keys.
        for key, s in observation_space.spaces.items():
            env_dict[f"obs_{key}"] = from_space(s, int_type, float_type)
            env_dict[f"next_obs_{key}"] = from_space(s, int_type, float_type)
    else:
        env_dict["obs"] = from_space(observation_space, int_type, float_type)
        env_dict["next_obs"] = from_space(observation_space, int_type, float_type)

    if isinstance(action_space, Tuple):
        for i, s in enumerate(action_space.spaces):
            env_dict[f"act{i}"] = from_space(s, int_type, float_type)
    elif isinstance(action_space, Dict):
        # "act_<key>", matching ``create_before_add_func``.
        for key, s in action_space.spaces.items():
            env_dict[f"act_{key}"] = from_space(s, int_type, float_type)
    else:
        env_dict["act"] = from_space(action_space, int_type, float_type)

    return env_dict
def create_before_add_func(env):
    """
    Create function to be used before ``ReplayBuffer.add``

    The returned callable takes one transition
    ``(obs, act, next_obs, rew, done)`` and returns a dict of keyword
    arguments. ``Tuple`` observations/actions are split into ``"<name>{i}"``
    entries, ``Dict`` observations/actions into ``"<name>_{key}"`` entries,
    and any other value is stored under ``"<name>"`` unchanged.

    Parameters
    ----------
    env : gym.Env
        Environment whose ``observation_space`` and ``action_space`` decide
        how values are flattened.

    Returns
    -------
    before_add : callable
        ``before_add(obs, act, next_obs, rew, done) -> dict``
    """
    def _pick_flattener(space):
        # Choose the flattening strategy once, when the function is built.
        if isinstance(space, Tuple):
            return lambda name, items: {f"{name}{i}": v for i, v in enumerate(items)}
        if isinstance(space, Dict):
            return lambda name, mapping: {f"{name}_{k}": v for k, v in mapping.items()}
        return lambda name, value: {f"{name}": value}

    flatten_obs = _pick_flattener(env.observation_space)
    flatten_act = _pick_flattener(env.action_space)

    def before_add(obs, act, next_obs, rew, done):
        # Merge in the same order as the keyword layout: obs, act, next_obs,
        # then the scalar reward/done entries.
        kwargs = {}
        kwargs.update(flatten_obs("obs", obs))
        kwargs.update(flatten_act("act", act))
        kwargs.update(flatten_obs("next_obs", next_obs))
        kwargs["rew"] = rew
        kwargs["done"] = done
        return kwargs

    return before_add